{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 基于MindSpore通过GPT实现情感分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%capture captured_output\n",
    "# The lab image ships with mindspore==2.3.0 preinstalled; to use another version, change MINDSPORE_VERSION below.\n",
    "# The variable is set with %env rather than `!export`: each `!` line runs in its own subshell, so an\n",
    "# exported value would be lost before the `pip install` line expands ${MINDSPORE_VERSION}.\n",
    "%env MINDSPORE_VERSION=2.3.0\n",
    "!pip uninstall mindspore -y\n",
    "!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/${MINDSPORE_VERSION}/MindSpore/unified/aarch64/mindspore-${MINDSPORE_VERSION}-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.mirrors.ustc.edu.cn/simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: mindnlp in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (0.3.1)\n",
      "Requirement already satisfied: mindspore in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.3.0)\n",
      "Requirement already satisfied: tqdm in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (4.66.4)\n",
      "Requirement already satisfied: requests in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.32.3)\n",
      "Requirement already satisfied: datasets in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.20.0)\n",
      "Requirement already satisfied: evaluate in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.4.2)\n",
      "Requirement already satisfied: tokenizers in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.19.1)\n",
      "Requirement already satisfied: safetensors in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.4.3)\n",
      "Requirement already satisfied: sentencepiece in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.2.0)\n",
      "Requirement already satisfied: regex in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2024.7.24)\n",
      "Requirement already satisfied: addict in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.4.0)\n",
      "Requirement already satisfied: ml-dtypes in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.4.0)\n",
      "Requirement already satisfied: pyctcdecode in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.5.0)\n",
      "Requirement already satisfied: jieba in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.42.1)\n",
      "Requirement already satisfied: pytest==7.2.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (7.2.0)\n",
      "Requirement already satisfied: attrs>=19.2.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (23.2.0)\n",
      "Requirement already satisfied: iniconfig in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (2.0.0)\n",
      "Requirement already satisfied: packaging in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (24.1)\n",
      "Requirement already satisfied: pluggy<2.0,>=0.12 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.5.0)\n",
      "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.2.0)\n",
      "Requirement already satisfied: tomli>=1.0.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (2.0.1)\n",
      "Requirement already satisfied: filelock in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (3.15.4)\n",
      "Requirement already satisfied: numpy>=1.17 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (1.26.4)\n",
      "Requirement already satisfied: pyarrow>=15.0.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (17.0.0)\n",
      "Requirement already satisfied: pyarrow-hotfix in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.6)\n",
      "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.3.8)\n",
      "Requirement already satisfied: pandas in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (2.2.2)\n",
      "Requirement already satisfied: xxhash in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (3.4.1)\n",
      "Requirement already satisfied: multiprocess in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.70.16)\n",
      "Requirement already satisfied: fsspec<=2024.5.0,>=2023.1.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets->mindnlp) (2024.5.0)\n",
      "Requirement already satisfied: aiohttp in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (3.10.0)\n",
      "Requirement already satisfied: huggingface-hub>=0.21.2 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.24.3)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (6.0.1)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (3.7)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (2.2.2)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (2024.7.4)\n",
      "Requirement already satisfied: protobuf>=3.13.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (5.27.2)\n",
      "Requirement already satisfied: asttokens>=2.0.4 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (2.0.5)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (10.4.0)\n",
      "Requirement already satisfied: scipy>=1.5.4 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (1.13.1)\n",
      "Requirement already satisfied: psutil>=5.6.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (5.9.0)\n",
      "Requirement already satisfied: astunparse>=1.6.3 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (1.6.3)\n",
      "Requirement already satisfied: pygtrie<3.0,>=2.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pyctcdecode->mindnlp) (2.5.0)\n",
      "Requirement already satisfied: hypothesis<7,>=6.14 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pyctcdecode->mindnlp) (6.108.5)\n",
      "Requirement already satisfied: six in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from asttokens>=2.0.4->mindspore->mindnlp) (1.16.0)\n",
      "Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore->mindnlp) (0.43.0)\n",
      "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp) (2.3.2)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp) (1.3.1)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp) (1.4.1)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp) (6.0.5)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp) (1.9.4)\n",
      "Requirement already satisfied: async-timeout<5.0,>=4.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp) (4.0.3)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from huggingface-hub>=0.21.2->datasets->mindnlp) (4.11.0)\n",
      "Requirement already satisfied: sortedcontainers<3.0.0,>=2.1.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from hypothesis<7,>=6.14->pyctcdecode->mindnlp) (2.4.0)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2.9.0.post0)\n",
      "Requirement already satisfied: pytz>=2020.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2024.1)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2024.1)\n",
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: jieba in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (0.42.1)\n",
      "env: HF_ENDPOINT=https://hf-mirror.com\n"
     ]
    }
   ],
   "source": [
    "# This notebook was validated with mindnlp 0.3.1 and imports from the private `mindnlp._legacy` package,\n",
    "# so the version is pinned instead of installing whatever release is current.\n",
    "!pip install mindnlp==0.3.1\n",
    "!pip install jieba\n",
    "# Point HF_ENDPOINT at the hf-mirror.com mirror for the downloads performed by later cells.\n",
    "%env HF_ENDPOINT=https://hf-mirror.com"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 1.081 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import mindspore\n",
    "from mindspore.dataset import text, GeneratorDataset, transforms\n",
    "from mindspore import nn\n",
    "\n",
    "from mindnlp.dataset import load_dataset\n",
    "\n",
    "from mindnlp._legacy.engine import Trainer, Evaluator\n",
    "from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback\n",
    "from mindnlp._legacy.metrics import Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fe304e05e37843cdbfd4f2f8b15bdd6d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme: 0.00B [00:00, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ec4f7b237ff441d0b79a2eb23459bc74",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/21.0M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1e183032ddae480ca37f7cea245188c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/20.5M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fb03161a78104470ac72ce879ecb9319",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/42.0M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aa928c54969d4bf0a210265663d67de0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d991b6c4f23d42e69f4fcb3150b806cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "784bdb842786421fa00b83aa9851043a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Request the IMDB 'train' and 'test' splits in one call; the result is indexed by split name below.\n",
    "imdb_ds = load_dataset('imdb', split=['train', 'test'])\n",
    "imdb_train = imdb_ds['train']\n",
    "imdb_test = imdb_ds['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "25000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the training split before it is divided into training/validation subsets.\n",
    "imdb_train.get_dataset_size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):\n",
    "    \"\"\"Tokenize, label-cast and batch a text-classification dataset.\n",
    "\n",
    "    Args:\n",
    "        dataset: Dataset exposing a 'text' column and a 'label' column.\n",
    "        tokenizer: Callable returning 'input_ids' and 'attention_mask' for a text;\n",
    "            its `pad_token_id` is used for per-batch padding when not on Ascend.\n",
    "        max_seq_len: Truncation length; on Ascend also the fixed padded length.\n",
    "        batch_size: Number of samples per batch.\n",
    "        shuffle: If True, shuffle the samples (over the whole dataset) first.\n",
    "\n",
    "    Returns:\n",
    "        The batched dataset; the mapped columns are 'input_ids', 'attention_mask' and 'labels'.\n",
    "    \"\"\"\n",
    "    # On Ascend every sample is padded to max_seq_len up front (presumably to keep tensor\n",
    "    # shapes fixed -- confirm); on other targets each batch is padded to its longest sample.\n",
    "    is_ascend = mindspore.get_context('device_target') == 'Ascend'\n",
    "\n",
    "    # The parameter is not called `text` so it does not shadow the imported mindspore.dataset.text module.\n",
    "    def tokenize(sample_text):\n",
    "        if is_ascend:\n",
    "            tokenized = tokenizer(sample_text, padding='max_length', truncation=True, max_length=max_seq_len)\n",
    "        else:\n",
    "            tokenized = tokenizer(sample_text, truncation=True, max_length=max_seq_len)\n",
    "        return tokenized['input_ids'], tokenized['attention_mask']\n",
    "\n",
    "    if shuffle:\n",
    "        # A shuffle buffer of only `batch_size` rows merely permutes neighbouring samples;\n",
    "        # a buffer as large as the dataset gives a global shuffle. shuffle() needs a buffer > 1.\n",
    "        dataset = dataset.shuffle(max(2, dataset.get_dataset_size()))\n",
    "\n",
    "    # Replace the raw text with token ids + attention mask, and cast the label to int32.\n",
    "    dataset = dataset.map(operations=[tokenize], input_columns=\"text\", output_columns=['input_ids', 'attention_mask'])\n",
    "    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns=\"label\", output_columns=\"labels\")\n",
    "\n",
    "    # Samples are already fixed-length on Ascend, so a plain batch suffices; otherwise pad per batch.\n",
    "    if is_ascend:\n",
    "        dataset = dataset.batch(batch_size)\n",
    "    else:\n",
    "        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),\n",
    "                                                             'attention_mask': (None, 0)})\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d52c1144dd443f0beb386da45155e07",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0.00/25.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "835c6ac53a61446fbb2e6680faf5a1d0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0.00B [00:00, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "79ffc167a24c47928d54688d2643dd8e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0.00B [00:00, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b68c6d96099412387647b78c1d1891d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0.00B [00:00, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5d74ac44c254dffb17b71e71fa65cb5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0.00/367 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import GPTTokenizer\n",
    "# Load the pretrained GPT tokenizer.\n",
    "gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')\n",
    "\n",
    "# Register the <bos>, <eos> and <pad> special tokens on the tokenizer.\n",
    "special_tokens_dict = dict(bos_token=\"<bos>\", eos_token=\"<eos>\", pad_token=\"<pad>\")\n",
    "num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Split the original training data 70/30 into training and validation subsets.\n",
    "# NOTE: this rebinds `imdb_train` to the 70% subset, so re-running the cell splits that subset again.\n",
    "imdb_train, imdb_val = imdb_train.split([0.7, 0.3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Tokenize and batch each split; shuffling is requested for the training split only.\n",
    "dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)\n",
    "dataset_val = process_dataset(imdb_val, gpt_tokenizer)\n",
    "dataset_test = process_dataset(imdb_test, gpt_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Tensor(shape=[4, 512], dtype=Int64, value=\n",
       " [[  500,   616,   850 ... 40480, 40480, 40480],\n",
       "  [  616,  4121,  4482 ... 40480, 40480, 40480],\n",
       "  [  249, 12266,   616 ... 40480, 40480, 40480],\n",
       "  [  606,  2310,   256 ... 40480, 40480, 40480]]),\n",
       " Tensor(shape=[4, 512], dtype=Int64, value=\n",
       " [[1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0]]),\n",
       " Tensor(shape=[4], dtype=Int32, value= [0, 0, 0, 0])]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: display the first batch produced by the training pipeline.\n",
    "next(dataset_train.create_tuple_iterator())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "caed57482f2d468d8563b545f6eab76a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0.00/457M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "71e49ea6a74d43a49d65f2d17cdc03de",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0.00/74.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following parameters in models are missing parameter:\n",
      "['score.weight']\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import GPTForSequenceClassification\n",
    "from mindspore.experimental.optim import Adam\n",
    "\n",
    "# Load the pretrained GPT model with a 2-label sequence-classification head.\n",
    "model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)\n",
    "model.config.pad_token_id = gpt_tokenizer.pad_token_id\n",
    "# Grow the embedding table by exactly the number of special tokens added to the tokenizer,\n",
    "# instead of a hard-coded count that could drift from the tokenizer setup.\n",
    "model.resize_token_embeddings(model.config.vocab_size + num_added_toks)\n",
    "\n",
    "optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)\n",
    "\n",
    "metric = Accuracy()\n",
    "\n",
    "# Callbacks that write checkpoints under 'checkpoint/': periodic ones and the best-scoring one\n",
    "# (auto_load=True reloads the best model when training finishes).\n",
    "ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)\n",
    "best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)\n",
    "\n",
    "# Evaluate on the held-out validation split rather than the training data, so the reported\n",
    "# score and the best-model selection are not measured on samples the model was trained on.\n",
    "trainer = Trainer(network=model, train_dataset=dataset_train,\n",
    "                  eval_dataset=dataset_val, metrics=metric,\n",
    "                  epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],\n",
    "                  jit=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The train will start from the checkpoint saved in 'checkpoint'.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "38c97d098dca4a6db887384bf1920cfd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4375 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint: 'gpt_imdb_finetune_epoch_0.ckpt' has been saved in epoch: 0.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f80707bf4b114939a5a1aa2b6d7f9f20",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4375 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.9301714285714285}\n",
      "---------------Best Model: 'gpt_imdb_finetune_best.ckpt' has been saved in epoch: 0.---------------\n",
      "Loading best model from 'checkpoint' with '['Accuracy']': [0.9301714285714285]...\n",
      "---------------The model is already load the best model from 'gpt_imdb_finetune_best.ckpt'.---------------\n",
      "CPU times: user 2h 14min 57s, sys: 27min 16s, total: 2h 42min 13s\n",
      "Wall time: 36min 55s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Run fine-tuning; the 'labels' column created by process_dataset is passed as the target column.\n",
    "trainer.run(tgt_columns=\"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c9737816b98044be9452f72d67626350",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/6250 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.89952}\n",
      "CPU times: user 36min 22s, sys: 26.6 s, total: 36min 49s\n",
      "Wall time: 9min 17s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Score the fine-tuned model on the IMDB test split with the same Accuracy metric.\n",
    "evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)\n",
    "evaluator.run(tgt_columns=\"labels\")"
   ]
  }
 ],
 "metadata": {
  "AIGalleryInfo": {
   "item_id": "1325b719-fc78-46c4-8f47-9f3623e9b0f4"
  },
  "flavorInfo": {
   "architecture": "X86_64",
   "category": "GPU"
  },
  "imageInfo": {
   "id": "e1a07296-22a8-4f05-8bc8-e936c8e54202",
   "name": "mindspore1.7.0-cuda10.1-py3.7-ubuntu18.04"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
